{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from utils import get_score, normalize\n",
    "# Directory that holds the saved model outputs and loss weights.\n",
    "result_dir = '/home/dyj/'\n",
    "# Module-level shorthand for torch.sigmoid used by the scoring cells.\n",
    "sigmoid = torch.sigmoid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Ground-truth label tensor that the get_score / get_loss_weight calls compare against.\n",
    "# Loaded through result_dir so the data location is configured in one place.\n",
    "label = torch.load(result_dir + 'label.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_weighted_score(resmat, loss_weight, temperature=1):\n",
    "    '''Scale sigmoid scores by a per-class factor derived from loss_weight.\n",
    "\n",
    "    Args:\n",
    "        resmat: tensor of raw model outputs (presumably samples x classes --\n",
    "            TODO confirm against the saved result files).\n",
    "        loss_weight: tensor that expand_as() can broadcast to resmat's shape,\n",
    "            e.g. the per-class output of get_loss_weight.\n",
    "        temperature: divisor applied to resmat before the sigmoid. The default\n",
    "            of 1 uses resmat unchanged, which is the original behaviour.\n",
    "\n",
    "    Returns:\n",
    "        sigmoid(resmat / temperature) * sqrt(1 - loss_weight + loss_weight.mean()),\n",
    "        a tensor shaped like resmat.\n",
    "    '''\n",
    "    # Skip the division for the default so no extra copy of resmat is made.\n",
    "    scaled = resmat if temperature == 1 else resmat / temperature\n",
    "    class_scale = torch.sqrt(1 - loss_weight + loss_weight.mean())\n",
    "    return torch.sigmoid(scaled) * class_scale.expand_as(resmat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_loss_weight(logit, label, topk=5):\n",
    "    '''Per-class miss rate of the top-k predictions.\n",
    "\n",
    "    For every class, counts the samples whose label row marks that class\n",
    "    (entry equal to 1) and how many of those samples do not have the class\n",
    "    among their topk highest-scoring entries in logit.\n",
    "\n",
    "    Args:\n",
    "        logit: score tensor; dimension 1 indexes the classes.\n",
    "        label: tensor iterated row by row; entries equal to 1 mark true classes.\n",
    "        topk: number of top-scoring classes that count as predicted\n",
    "            (default 5, capped at the number of classes).\n",
    "\n",
    "    Returns:\n",
    "        Tensor of length class_num with errors / samples per class. A class\n",
    "        without any marked sample gets 0 rather than the NaN that 0 / 0 would\n",
    "        give; a single NaN would otherwise propagate through\n",
    "        loss_weight.mean() into every weighted score.\n",
    "    '''\n",
    "    class_num = logit.size(1)\n",
    "    predicted = logit.topk(min(topk, class_num), 1)[1]\n",
    "    sample_per_class = torch.zeros(class_num)\n",
    "    error_per_class = torch.zeros(class_num)\n",
    "    for predicted_row, marks in zip(predicted, label):\n",
    "        # A set of plain ints makes the membership test below O(1).\n",
    "        predict_labels = set(predicted_row.tolist())\n",
    "        for true_label in np.where(marks.numpy() == 1)[0].tolist():\n",
    "            sample_per_class[true_label] += 1\n",
    "            if true_label not in predict_labels:\n",
    "                error_per_class[true_label] += 1\n",
    "    # clamp keeps the denominator >= 1 so empty classes yield 0 / 1 = 0.\n",
    "    return error_per_class / sample_per_class.clamp(min=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TextCNN1: load raw outputs and loss weights (both via result_dir), keep the weighted score.\n",
    "cnn1 = torch.load(result_dir + 'TextCNN1_2017-07-27#10:15:20_res.pt')\n",
    "cnn1_loss_weight = torch.load(result_dir + 'TextCNN1_loss_weight.pt')\n",
    "logit1_ = get_weighted_score(cnn1, cnn1_loss_weight)\n",
    "# Unbind the inputs once the weighted score exists.\n",
    "del cnn1, cnn1_loss_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TextCNN2: load raw outputs and loss weights (both via result_dir), keep the weighted score.\n",
    "cnn2 = torch.load(result_dir + 'TextCNN2_2017-07-27#10:32:21_res.pt')\n",
    "cnn2_loss_weight = torch.load(result_dir + 'TextCNN2_2017-07-27#10:32:21_loss_weight.pt')\n",
    "logit2_ = get_weighted_score(cnn2, cnn2_loss_weight)\n",
    "# Unbind the inputs once the weighted score exists.\n",
    "del cnn2, cnn2_loss_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# RNN1: load raw outputs and loss weights (both via result_dir), keep the weighted score.\n",
    "rnn1 = torch.load(result_dir + 'RNN1_2017-07-27#10:48:05_res.pt')\n",
    "rnn1_loss_weight = torch.load(result_dir + 'RNN1_loss_weight.pt')\n",
    "logit3_ = get_weighted_score(rnn1, rnn1_loss_weight)\n",
    "# Unbind the inputs once the weighted score exists.\n",
    "del rnn1, rnn1_loss_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# RNN2: load raw outputs and loss weights (both via result_dir), keep the weighted score.\n",
    "rnn2 = torch.load(result_dir + 'RNN2_2017-07-27#10:41:03_res.pt')\n",
    "rnn2_loss_weight = torch.load(result_dir + 'RNN2_loss_weight.pt')\n",
    "logit4_ = get_weighted_score(rnn2, rnn2_loss_weight)\n",
    "# Unbind the inputs once the weighted score exists.\n",
    "del rnn2, rnn2_loss_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# RCNN1: load raw outputs and loss weights (both via result_dir), keep the weighted score.\n",
    "rcnn1 = torch.load(result_dir + 'RCNN1_2017-07-27#11:01:07_res.pt')\n",
    "rcnn1_loss_weight = torch.load(result_dir + 'RCNN1_loss_weight.pt')\n",
    "logit5_ = get_weighted_score(rcnn1, rcnn1_loss_weight)\n",
    "# Unbind the inputs once the weighted score exists.\n",
    "del rcnn1, rcnn1_loss_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# RCNNcha: load raw outputs and loss weights (both via result_dir), keep the weighted score.\n",
    "rcnncha = torch.load(result_dir + 'RCNNcha_2017-07-27#16:19:23_res.pt')\n",
    "rcnncha_loss_weight = torch.load(result_dir + 'RCNNcha_loss_weight.pt')\n",
    "logit6_ = get_weighted_score(rcnncha, rcnncha_loss_weight)\n",
    "# Unbind the inputs once the weighted score exists.\n",
    "del rcnncha, rcnncha_loss_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# FastText1: load raw outputs and loss weights (both via result_dir), keep the weighted score.\n",
    "fasttext1 = torch.load(result_dir + 'FastText1_2017-07-29#10:31:43_res.pt')\n",
    "fasttext1_loss_weight = torch.load(result_dir + 'FastText1_loss_weight.pt')\n",
    "logit7_ = get_weighted_score(fasttext1, fasttext1_loss_weight)\n",
    "# Unbind the inputs once the weighted score exists.\n",
    "del fasttext1, fasttext1_loss_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# FastText1 (char): outputs come from result_dir, loss weights from the FastText snapshot directory.\n",
    "fasttext1_char = torch.load(result_dir + 'FastText1_char.pt')\n",
    "# The name fasttext1_loss_weight is reused here for the char model's weights.\n",
    "fasttext1_loss_weight = torch.load('snapshots/FastText/layer_2_loss_weight_char.pt')\n",
    "logit8_ = get_weighted_score(fasttext1_char, fasttext1_loss_weight)\n",
    "del fasttext1_char, fasttext1_loss_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TextCNN1 (augment): outputs loaded via result_dir; loss weights are computed from the labels.\n",
    "cnn1_augment = torch.load(result_dir + 'TextCNN1_augment.pt')\n",
    "cnn1_loss_weight = get_loss_weight(cnn1_augment, label)\n",
    "logit9_ = get_weighted_score(cnn1_augment, cnn1_loss_weight)\n",
    "del cnn1_augment, cnn1_loss_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TextCNN1 (word_char): outputs loaded via result_dir; loss weights are computed from the labels.\n",
    "cnn1_word_char = torch.load(result_dir + 'TextCNN1_word_char.pt')\n",
    "cnn1_loss_weight = get_loss_weight(cnn1_word_char, label)\n",
    "logit10_ = get_weighted_score(cnn1_word_char, cnn1_loss_weight)\n",
    "del cnn1_word_char, cnn1_loss_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TextCNN1 (word_top1): outputs loaded via result_dir; loss weights are computed from the labels.\n",
    "cnn1_word_top1 = torch.load(result_dir + 'TextCNN1_word_top1.pt')\n",
    "cnn1_loss_weight = get_loss_weight(cnn1_word_top1, label)\n",
    "logit11_ = get_weighted_score(cnn1_word_top1, cnn1_loss_weight)\n",
    "del cnn1_word_top1, cnn1_loss_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TextCNN1 (word_top3): outputs loaded via result_dir; loss weights are computed from the labels.\n",
    "cnn1_word_top3 = torch.load(result_dir + 'TextCNN1_word_top3.pt')\n",
    "cnn1_loss_weight = get_loss_weight(cnn1_word_top3, label)\n",
    "logit12_ = get_weighted_score(cnn1_word_top3, cnn1_loss_weight)\n",
    "del cnn1_word_top3, cnn1_loss_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Unweighted mean of the twelve weighted scores computed above.\n",
    "logit_ = (logit1_ + logit2_ + logit3_ + logit4_ + logit5_ + logit6_ + logit7_ + logit8_ + logit9_ + logit10_ + logit11_ + logit12_) / 12\n",
    "# Print explicitly: the cell ends with a del statement, so a bare expression here displays nothing.\n",
    "print(get_score(logit_, label))\n",
    "# logit5_ is deliberately kept: a later ensemble cell still adds logit5_ * 0.005.\n",
    "del logit1_, logit2_, logit3_, logit4_, logit6_, logit7_, logit8_, logit9_, logit10_, logit11_, logit12_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# RNN10: raw outputs from result_dir, loss weights from the RNN snapshot directory.\n",
    "rnn10 = torch.load(result_dir + 'RNN10_cal_res.pt')\n",
    "rnn10_loss_weight = torch.load('snapshots/RNN/layer_11_loss_weight_3.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sigmoid of rnn10 / 10, re-weighted per class through the shared helper.\n",
    "logit1 = get_weighted_score(rnn10 / 10, rnn10_loss_weight)\n",
    "del rnn10, rnn10_loss_weight\n",
    "# get_score(logit1, label)\n",
    "# 0.42681"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TextCNN10 (char): load the outputs and derive per-class loss weights from the labels.\n",
    "cnn10_char = torch.load(result_dir + 'TextCNN10_char.pt')\n",
    "cnn10_char_loss_weight = get_loss_weight(cnn10_char, label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sigmoid of cnn10_char / 10, re-weighted per class through the shared helper.\n",
    "logit2 = get_weighted_score(cnn10_char / 10, cnn10_char_loss_weight)\n",
    "del cnn10_char, cnn10_char_loss_weight\n",
    "# get_score(logit2, label)\n",
    "# 0.41843"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TextCNN10 (top1): load the outputs and derive per-class loss weights from the labels.\n",
    "cnn10_top1 = torch.load(result_dir + 'TextCNN10_top1.pt')\n",
    "cnn10_top1_loss_weight = get_loss_weight(cnn10_top1, label)\n",
    "# torch.save(cnn10_top1_loss_weight, 'snapshots/TextCNN/layer_11_loss_weight_top1_top5.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sigmoid of cnn10_top1 / 10, re-weighted per class through the shared helper.\n",
    "logit3 = get_weighted_score(cnn10_top1 / 10, cnn10_top1_loss_weight)\n",
    "del cnn10_top1, cnn10_top1_loss_weight\n",
    "# get_score(logit3, label)\n",
    "# 0.42705"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TextCNN10 (top1_char): load the outputs and derive per-class loss weights from the labels.\n",
    "cnn10_top1_char = torch.load(result_dir + 'TextCNN10_top1_char.pt')\n",
    "cnn10_top1_char_loss_weight = get_loss_weight(cnn10_top1_char, label)\n",
    "# torch.save(cnn10_top1_char_loss_weight, 'snapshots/TextCNN/layer_11_loss_weight_top1_char_top5.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sigmoid of cnn10_top1_char / 10, re-weighted per class through the shared helper.\n",
    "logit4 = get_weighted_score(cnn10_top1_char / 10, cnn10_top1_char_loss_weight)\n",
    "del cnn10_top1_char, cnn10_top1_char_loss_weight\n",
    "# get_score(logit4, label)\n",
    "# 0.41850"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# FastText10: raw outputs and loss weights, both read through result_dir.\n",
    "fasttext10 = torch.load(result_dir + 'FastText10_res.pt')\n",
    "fasttext10_loss_weight = torch.load(result_dir + 'FastText10_loss_weight.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sigmoid of fasttext10 / 10, re-weighted per class through the shared helper.\n",
    "logit5 = get_weighted_score(fasttext10 / 10, fasttext10_loss_weight)\n",
    "del fasttext10, fasttext10_loss_weight\n",
    "# get_score(logit5, label)\n",
    "# 0.41932"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TextCNN5: raw outputs from /mnt/result/results, loss weights from the TextCNN1 snapshot directory.\n",
    "cnn5 = torch.load('/mnt/result/results/TextCNN5_12h.pt')\n",
    "cnn5_loss_weight = torch.load('snapshots/TextCNN1/layer_6_loss_weight_3.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sigmoid of cnn5 / 5, re-weighted per class through the shared helper.\n",
    "logit6 = get_weighted_score(cnn5 / 5, cnn5_loss_weight)\n",
    "del cnn5, cnn5_loss_weight\n",
    "# get_score(logit6, label)\n",
    "# 0.42417"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# RNN1 (char): raw outputs from /mnt/result/results, loss weights from the RNN snapshot directory.\n",
    "rnn1_char = torch.load('/mnt/result/results/RNN1_char.pt')\n",
    "rnn1_char_loss_weight = torch.load('snapshots/RNN/layer_2_loss_weight_char.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sigmoid of rnn1_char / 1, re-weighted per class through the shared helper.\n",
    "logit7 = get_weighted_score(rnn1_char / 1, rnn1_char_loss_weight)\n",
    "del rnn1_char, rnn1_char_loss_weight\n",
    "# get_score(logit7, label)\n",
    "# 0.40001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Weighted blend of logit1..logit7.\n",
    "logit = (logit1 * 0.2803 + logit2 * 0.1410 + logit3 * 0.1929 + logit4 * 0.1499 +\n",
    "         logit5 * 0.0230 + logit6 * 0.1081 + logit7 * 0.0015)\n",
    "get_score(logit, label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.4893289791299433, 0.6086568286493091, 0.4320764467955468)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Weighted blend of logit1..logit7 (hand-rounded weights).\n",
    "logit = (logit1 * 0.29 + logit2 * 0.142 + logit3 * 0.22 + logit4 * 0.15 +\n",
    "         logit5 * 0.023 + logit6 * 0.11 + logit7 * 0.002)\n",
    "get_score(logit, label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.4894996710619259, 0.6087052496959443, 0.43211521440350509)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Weighted blend of logit1..logit7 with a larger share for logit6.\n",
    "logit = (logit1 * 0.29 + logit2 * 0.14 + logit3 * 0.19 + logit4 * 0.15 +\n",
    "         logit5 * 0.02 + logit6 * 0.16 + logit7 * 0.005)\n",
    "get_score(logit, label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Same blend as above plus a small contribution from logit5_ (the RCNN1 weighted score).\n",
    "# NOTE(review): logit5_ must still be bound here; the averaging cell earlier lists it in its del statement.\n",
    "logit = (logit1 * 0.29 + logit2 * 0.14 + logit3 * 0.19 + logit4 * 0.15 +\n",
    "         logit5 * 0.02 + logit6 * 0.16 + logit7 * 0.005 +\n",
    "         logit5_*0.005)\n",
    "get_score(logit, label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Blend of logit1..logit6 only; the logit7 term is switched off.\n",
    "logit = (logit1 * 0.29 + logit2 * 0.14 + logit3 * 0.19 +\n",
    "         logit4 * 0.15 + logit5 * 0.02 + logit6 * 0.16)  # + logit7 * 0.00\n",
    "get_score(logit, label)\n",
    "# (1.4895235390079529, 0.6086924323600703, 0.43211076380530711)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Side-by-side scores of logit1 and logit3.\n",
    "(get_score(logit1, label), get_score(logit3, label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Side-by-side scores of logit2 and logit4.\n",
    "(get_score(logit2, label), get_score(logit4, label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the ncnnp7 array as a float tensor and display its mean.\n",
    "a = torch.from_numpy(np.load('/mnt/result/results/ncnnp7.npy')).float()\n",
    "a.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "'''\n",
    "logit1:  rnn10_word_top5 0.42681\n",
    "logit2:  cnn10_char_top5 0.41843\n",
    "logit3:  cnn10_word_top1 0.42705\n",
    "logit4:  cnn10_char_top1 0.41850\n",
    "logit5:  fasttext10      0.41932\n",
    "logit6:  cnn4            0.42331\n",
    "xxlogit7:  cnn1_augment    0.41449xx\n",
    "xxlogit8:  cnn5_shuffle    0.42301xx\n",
    "logit9:  cnn1_word_char  0.41004\n",
    "xxlogit10: rnn10_cnn7      0.42994xx\n",
    "xxlogit11: rnn10_cnn8      0.43003xx\n",
    "xxlogit12: rnn10_cnn9      0.43009xx\n",
    "logit13: cnn1_word_top1    0.41237\n",
    "logit14: cnn1_word_top3    0.41282\n",
    "xxlogit15: rcnn              0.40916xx\n",
    "xxlogit16: rcnncha           0.40780xx\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit9, logit13 and logit14 are not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.306 + logit2 * 0.18 + logit3 * 0.14 + logit4 * 0.13 +\n",
    "         logit5 * 0.058 + logit6 * 0.12 + logit9 * 0.05 +\n",
    "         logit13 * 0.01 + logit14 * 0.02)\n",
    "print(get_score(logit, label))\n",
    "del logit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Same blend as the previous cell with a small negative weight on logit7.\n",
    "# NOTE(review): logit9, logit13 and logit14 are not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.306 + logit2 * 0.18 + logit3 * 0.14 + logit4 * 0.13 +\n",
    "         logit5 * 0.058 + logit6 * 0.12 + logit9 * 0.05 +\n",
    "         logit13 * 0.01 + logit14 * 0.02 + logit7 * -0.009)\n",
    "print(get_score(logit, label))\n",
    "del logit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Blend of logit1..logit6 with rounded weights.\n",
    "logit = (logit1 * 0.30 + logit2 * 0.14 + logit3 * 0.19 +\n",
    "         logit4 * 0.15 + logit5 * 0.02 + logit6 * 0.1)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.489309057763871, 0.6086297698291306, 0.43206113405662766)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Rebuild the blend here: the previous cell deletes logit after printing its score,\n",
    "# so without this assignment the cell fails on a top-to-bottom run.\n",
    "logit = logit1 * 0.30 + logit2 * 0.14 + logit3 * 0.19 + logit4 * 0.15 + logit5 * 0.02 + logit6 * 0.1\n",
    "# Score a random subset of at most 100000 rows (unseeded, so the subset differs per run).\n",
    "index = torch.from_numpy(np.random.permutation(logit.size(0))[:100000])\n",
    "get_score(logit[index], label[index])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit10 is not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.2069 + logit2 * 0.0995 + logit3 * 0.1710 + logit4 * 0.1426 +\n",
    "         logit5 * 0.0157 + logit6 * 0.0964 + logit10 * 0.1622)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.4891288461364194, 0.6085015964703905, 0.4319813737533425)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit11 is not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.2089 + logit2 * 0.1011 + logit3 * 0.1650 + logit4 * 0.1432 +\n",
    "         logit5 * 0.0161 + logit6 * 0.0950 + logit11 * 0.1666)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.4891470589385698, 0.6084716893533512, 0.43196783370941988)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit12 is not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.2113 + logit2 * 0.1021 + logit3 * 0.1610 + logit4 * 0.1423 +\n",
    "         logit5 * 0.0153 + logit6 * 0.0940 + logit12 * 0.1694)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.4891114906638721, 0.6084802342439338, 0.43196914723446073)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit9 and logit10 are not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.1877 + logit2 * 0.1046 + logit3 * 0.1358 + logit4 * 0.1378 +\n",
    "         logit5 * 0.0047 + logit6 * 0.1005 +\n",
    "         logit7 * 0.0875 + logit9 * -0.0113 + logit10 * 0.1492)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.4888985429296957, 0.6084147234161333, 0.43191821161249683)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit9 and logit11 are not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.1899 + logit2 * 0.1053 + logit3 * 0.1323 + logit4 * 0.1378 +\n",
    "         logit5 * 0.0051 + logit6 * 0.0983 +\n",
    "         logit7 * 0.0869 + logit9 * -0.0117 + logit11 * 0.1509)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.4888500077740734, 0.6084019060802593, 0.43190766765478322)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit9 and logit12 are not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.1923 + logit2 * 0.1065 + logit3 * 0.1288 + logit4 * 0.1389 +\n",
    "         logit5 * 0.0061 + logit6 * 0.0987 +\n",
    "         logit7 * 0.0861 + logit9 * -0.0107 + logit12 * 0.1520)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.4889230059676828, 0.6084275407520073, 0.43192672980053676)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit8 and logit9 are not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.2802 + logit2 * 0.1388 + logit3 * 0.1672 + logit4 * 0.1504 +\n",
    "         logit5 * 0.0124 + logit6 * 0.1103 +\n",
    "         logit7 * 0.1099 + logit8 * -0.0607 + logit9 * -0.0091)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.4892614641474382, 0.608448902978464, 0.43196597505027462)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit9 is not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.2722 + logit2 * 0.1365 + logit3 * 0.1506 + logit4 * 0.1480 +\n",
    "         logit5 * 0.0048 + logit6 * 0.1024 +\n",
    "         logit7 * 0.0918 + logit9 * -0.0105)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.4890620566459902, 0.6085129896578341, 0.43198149476099751)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit8 and logit9 are not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.3098 + logit2 * 0.1439 + logit3 * 0.1995 + logit4 * 0.1555 +\n",
    "         logit5 * 0.0242 + logit6 * 0.1062 +\n",
    "         logit8 * -0.0338 + logit9 * -0.0053)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.48937243263684, 0.6085813487824955, 0.43204206494977304)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit8 is not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.2840 + logit2 * 0.1366 + logit3 * 0.1678 + logit4 * 0.1477 +\n",
    "         logit5 * 0.0108 + logit6 * 0.1091 +\n",
    "         logit7 * 0.1080 + logit8 * -0.0671)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.4892897725974128, 0.6084759617986425, 0.43198199490039124)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Weighted blend of logit1..logit7.\n",
    "logit = (logit1 * 0.2749 + logit2 * 0.1346 + logit3 * 0.1491 + logit4 * 0.1470 +\n",
    "         logit5 * 0.0046 + logit6 * 0.1031 +\n",
    "         logit7 * 0.0879)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.4891557620183564, 0.6085742280403432, 0.4320202421650991)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit8 is not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.3127 + logit2 * 0.1415 + logit3 * 0.2013 + logit4 * 0.1546 +\n",
    "         logit5 * 0.0242 + logit6 * 0.1042 +\n",
    "         logit8 * -0.0419)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.4893664301230018, 0.6085671072981911, 0.43203438236709918)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): logit9 is not assigned in any cell of this notebook.\n",
    "logit = (logit1 * 0.3026 + logit2 * 0.1417 + logit3 * 0.1899 + logit4 * 0.1538 +\n",
    "         logit5 * 0.0204 + logit6 * 0.0968 +\n",
    "         logit9 * -0.0077)\n",
    "print(get_score(logit, label))\n",
    "del logit\n",
    "# (1.4892167342178562, 0.6085656831497606, 0.43202106744444457)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): %cd changes the kernel's working directory to utils/ for every later cell;\n",
    "# relative paths below are resolved from there (some use '../snapshots', others 'snapshots').\n",
    "%cd utils\n",
    "from resmat import search_model_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Stack the five score tensors along a new dimension 2 (one slice per model).\n",
    "models = [logit1, logit2, logit3, logit4, logit5]\n",
    "logit = torch.stack(models, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Starting weights for the search, one per stacked model.\n",
    "init_weight = torch.Tensor([0.4, 0.18, 0.22, 0.1, 0.1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Show the weights currently stored in the Stack snapshot file.\n",
    "cur_weight = torch.load('../snapshots/Stack/model_weight.pt')\n",
    "cur_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Run the weight search from init_weight on the stacked scores.\n",
    "# NOTE(review): the meaning of 10 and 0.001 is defined in resmat.search_model_weight, which is\n",
    "# not visible here; the last argument is the same path the previous cell loads cur_weight from.\n",
    "search_model_weight(logit, label, init_weight, 10, 0.001, '../snapshots/Stack/model_weight.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the layer 11 and layer 12 char results from the TextCNN snapshot directory.\n",
    "layer_11 = torch.load('snapshots/TextCNN/layer_11_cal_res_char.pt')\n",
    "layer_12 = torch.load('snapshots/TextCNN/layer_12_cal_res_char.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Combine layer_11 / 11 (weight 5.5) with the layer_12 - layer_11 difference (weight 1),\n",
    "# divide by 6.5 and scale by 12.\n",
    "# NOTE(review): presumably layer_N holds the summed outputs of the first N layers -- confirm.\n",
    "layer_12_weighted = (layer_11 / 11 * 5.5 + (layer_12 - layer_11)) / 6.5 * 12\n",
    "layer_12_loss_weight = get_loss_weight(layer_12_weighted, label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Score the raw re-weighted outputs, then the sigmoid / loss-weight scaled version.\n",
    "print(get_score(layer_12_weighted, label))\n",
    "get_score(get_weighted_score(layer_12_weighted / 12, layer_12_loss_weight), label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Persist the re-weighted layer 12 outputs and the loss weights derived from them.\n",
    "torch.save(layer_12_weighted, 'snapshots/TextCNN/layer_12_weighted_cal_res_char.pt')\n",
    "torch.save(layer_12_loss_weight, 'snapshots/TextCNN/layer_13_weighted_loss_weight_char.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sum three test-set result tensors and store the total in two locations.\n",
    "rnn_cnn2_test = torch.load('results/RNN10_CNN2_test_res.pt')\n",
    "cnn3_test = torch.load('results/CNN3_2017-08-07#12:41:45_test_res.pt')\n",
    "cnn2_test = torch.load('results/CNN2_2017-08-07#12:53:50_test_res.pt')\n",
    "# Compute the sum once instead of once per save.\n",
    "rnn10_cnn7_test = rnn_cnn2_test + cnn3_test + cnn2_test\n",
    "torch.save(rnn10_cnn7_test, result_dir + 'RNN10_CNN7_test_res.pt')\n",
    "torch.save(rnn10_cnn7_test, 'results/RNN10_CNN7_test_res.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): every tensor referenced here is either del'ed by an earlier cell or never\n",
    "# assigned in this notebook, so this cell needs them reloaded before it can run.\n",
    "logit = (get_weighted_score(cnn1, cnn1_loss_weight) * 0 +\n",
    "         get_weighted_score(rnn1, rnn1_loss_weight) * 0. +\n",
    "         get_weighted_score(rcnn1, rcnn1_loss_weight) * 0 +\n",
    "         get_weighted_score(rcnncha, rcnncha_loss_weight) * 0 +\n",
    "         get_weighted_score(fasttext1_char, fasttext1_loss_weight) * 0 +\n",
    "         get_weighted_score(fasttext4 / 4, fasttext4_loss_weight) * 0 +\n",
    "         get_weighted_score(fasttext10 / 10, fasttext10_loss_weight) * 0.0469 +\n",
    "         get_weighted_score(rnn10_cnn7 / 17, rnn10_cnn7_loss_weight) * 0.6604 +\n",
    "         get_weighted_score(cnn4 / 4, cnn4_loss_weight) * 0.1536)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): rnn10_cnn8 is not assigned in any cell of this notebook; load it before running.\n",
    "rnn10_cnn8_loss_weight_new = get_loss_weight(rnn10_cnn8, label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Persist the loss weights computed in the previous cell.\n",
    "torch.save(rnn10_cnn8_loss_weight_new, 'snapshots/TextCNN/layer_19_loss_weight_top1.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the top3 loss weights saved for layers 18 and 19.\n",
    "layer_18_loss_weight_top3 = torch.load('snapshots/TextCNN/layer_18_loss_weight_top3.pt')\n",
    "layer_19_loss_weight_top3 = torch.load('snapshots/TextCNN/layer_19_loss_weight_top3.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Compare the mean loss weight of layer 18 (printed) and layer 19 (cell output).\n",
    "print(layer_18_loss_weight_top3.mean())\n",
    "layer_19_loss_weight_top3.mean()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
